{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How-to Log Sum Exp\n",
    "\n",
    "Using flash-attention intensively you will at some point hear about `lse` values being returend. `lse` stands for \"log-sum-exp\" and can be used to compute softmax (and thereby also attention) in a blockwise and stable fashion. This notebook aims to explain how this works."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by defining a naive softmax function .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def naive_softmax(x: torch.Tensor) -> torch.Tensor:\n",
    "    return x.exp() / x.exp().sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ".. and verify that its output matches the output of the official `torch.softmax()` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a tensor([0.0695, 0.0043, 0.0450, 0.0530, 0.0109, 0.0142, 0.0665, 0.6751, 0.0143,\n",
      "        0.0472])\n",
      "b tensor([0.0695, 0.0043, 0.0450, 0.0530, 0.0109, 0.0142, 0.0665, 0.6751, 0.0143,\n",
      "        0.0472])\n",
      "allclose True\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(10)  # generate normally distributed random numbers\n",
    "a = torch.softmax(x, dim=-1) # reference output\n",
    "b = naive_softmax(x) # our naive version\n",
    "\n",
    "print(\"a\", a)\n",
    "print(\"b\", b)\n",
    "print(\"allclose\", torch.allclose(a, b, atol=1e-6))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our naive softmax function has a problem with numerical stability when it gets input vectors with larger elements:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 0., 0., 0., 0., 0., nan, 0., 0.])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "naive_softmax(x * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we discuss how to fix this let's first try to compute softmax in a blockwise fashion.\n",
    "Let's start by generating a random vector and split it into two chunks of equal size and compute softmax on these chunks individually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We have:\n",
      "s1 = tensor([0.7998, 0.0296, 0.1147, 0.0265, 0.0294])\n",
      "s2 = tensor([0.1985, 0.3296, 0.0460, 0.1054, 0.3205])\n",
      "We want:\n",
      "target = tensor([0.4914, 0.0182, 0.0705, 0.0163, 0.0181, 0.0766, 0.1271, 0.0177, 0.0406,\n",
      "        0.1236])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(10)\n",
    "\n",
    "x1,x2 = torch.chunk(x, 2)\n",
    "s1 = naive_softmax(x1)\n",
    "s2 = naive_softmax(x2)\n",
    "\n",
    "print(\"We have:\")\n",
    "print(f\"s1 = {s1}\")\n",
    "print(f\"s2 = {s2}\")\n",
    "\n",
    "target = naive_softmax(x)\n",
    "print(\"We want:\")\n",
    "print(f\"target = {target}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we look at `naive_softmax()` we note that its output has been divided by `x.exp().sum()`. We can call this the \"sum exp\" value (note the similarity to \"log sum exp\") and we can use it to \"undo\" the softmax normalization and thereby compute combine multiple softmax chunks if we have this vaue for each chunk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After correction with help of sum_exp values:\n",
      "s_combined tensor([0.4914, 0.0182, 0.0705, 0.0163, 0.0181, 0.0766, 0.1271, 0.0177, 0.0406,\n",
      "        0.1236])\n",
      "allclose(s_combined, target): True\n"
     ]
    }
   ],
   "source": [
    "sum_exp_x1 = x1.exp().sum()\n",
    "sum_exp_x2 = x2.exp().sum()\n",
    "s1_corrected = s1 * sum_exp_x1 / (sum_exp_x1 + sum_exp_x2)\n",
    "s2_corrected = s2 * sum_exp_x2 / (sum_exp_x1 + sum_exp_x2)\n",
    "\n",
    "print(\"After correction with help of sum_exp values:\")\n",
    "s_combined = torch.cat([s1_corrected, s2_corrected])\n",
    "print(\"s_combined\", s_combined)\n",
    "\n",
    "print(\"allclose(s_combined, target):\", torch.allclose(s_combined, target))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... but is this helpful at all? Yes, and it becomes more obivous when we realize that we can return this value from our softmax function and we can do the correction in a blockwise fashion in a loop by accumulating the `sum_exp` value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a tensor([0.0608, 0.0217, 0.0446, 0.0180, 0.0289, 0.2553, 0.0101, 0.0150, 0.0699,\n",
      "        0.0231, 0.0672, 0.0388, 0.0186, 0.0126, 0.0731, 0.0132, 0.0191, 0.0362,\n",
      "        0.1688, 0.0050])\n",
      "b tensor([0.0608, 0.0217, 0.0446, 0.0180, 0.0289, 0.2553, 0.0101, 0.0150, 0.0699,\n",
      "        0.0231, 0.0672, 0.0388, 0.0186, 0.0126, 0.0731, 0.0132, 0.0191, 0.0362,\n",
      "        0.1688, 0.0050])\n",
      "allclose: True\n"
     ]
    }
   ],
   "source": [
    "from typing import Tuple, Sequence\n",
    "\n",
    "def naive_softmax2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    sum_exp = x.exp().sum()\n",
    "    return x.exp() / sum_exp, sum_exp\n",
    "\n",
    "\n",
    "def naive_blockwise_softmax(blocks: Sequence[torch.Tensor]) -> torch.Tensor:\n",
    "    total_sum_exp = 0\n",
    "    blocks_out = []\n",
    "    for block in blocks:\n",
    "        block_softmax, block_sum_exp = naive_softmax2(block)\n",
    "        blocks_out.append((block_softmax, block_sum_exp))\n",
    "        total_sum_exp += block_sum_exp\n",
    "\n",
    "    out = []\n",
    "    for block_softmax, block_sum_exp in blocks_out:\n",
    "        out.append(block_softmax * block_sum_exp / total_sum_exp)\n",
    "\n",
    "    return torch.cat(out)\n",
    "\n",
    "x_long = torch.randn(20)\n",
    "chunks = torch.chunk(x_long, 4)\n",
    "a = naive_blockwise_softmax(chunks)\n",
    "b = torch.softmax(x_long, dim=-1)\n",
    "print(\"a\", a)\n",
    "print(\"b\", b)\n",
    "print(\"allclose:\", torch.allclose(a, b))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OK, then now let's look at the numerical stability of softmax. First we can observe a interesting property of the softmax function: its output is shift/translation invariant (i.e. `f(x+a)=f(x)`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])\n",
      "tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])\n",
      "tensor([0.0107, 0.0226, 0.0729, 0.0937, 0.0722, 0.2008, 0.1214, 0.4057])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(8)\n",
    "print(naive_softmax(x))\n",
    "print(naive_softmax(x+5))\n",
    "print(naive_softmax(x-3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This porperty allows us to deal with problematic large inputs simply by subtracting their maximum:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stable_softmax(x):\n",
    "    m = x.max()\n",
    "    return (x-m).exp() / (x-m).exp().sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This \"stable\" function now can also deal with larger value that were problematic for our naive function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "naive:  tensor([0., nan, 0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "stable:  tensor([4.6243e-44, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])\n",
      "torch:  tensor([4.6243e-44, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])\n"
     ]
    }
   ],
   "source": [
    "large_input = torch.randn(10) * 100\n",
    "\n",
    "print(\"naive: \", naive_softmax(large_input))\n",
    "print(\"stable: \", stable_softmax(large_input))\n",
    "print(\"torch: \", torch.softmax(large_input, dim=-1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stable_softmax2(x):\n",
    "    \"\"\"returns softmax result and log sum exp\"\"\"\n",
    "    m = x.max()\n",
    "    a = (x - m).exp()\n",
    "    b = a.sum()\n",
    "    lse = m + torch.log(b)\n",
    "    return a / b, lse"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again we can now use this to combine two softmax block results, but to do it in the same way as before we would need to calculate the exp() values.. which is as we know numerically not stable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0856, 0.0951, 0.0390, 0.0288, 0.0738, 0.0110, 0.0129, 0.0293, 0.0190,\n",
      "        0.1449, 0.0112, 0.0428, 0.0573, 0.0605, 0.0056, 0.0085, 0.1530, 0.0500,\n",
      "        0.0044, 0.0677])\n",
      "tensor([0.0856, 0.0951, 0.0390, 0.0288, 0.0738, 0.0110, 0.0129, 0.0293, 0.0190,\n",
      "        0.1449, 0.0112, 0.0428, 0.0573, 0.0605, 0.0056, 0.0085, 0.1530, 0.0500,\n",
      "        0.0044, 0.0677]) True\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(20)\n",
    "\n",
    "a = torch.softmax(x, dim=-1)\n",
    "\n",
    "x1, x2 = x.chunk(2)\n",
    "\n",
    "b1, lse1 = stable_softmax2(x1)\n",
    "b2, lse2 = stable_softmax2(x2)\n",
    "\n",
    "c1 = b1 * torch.exp(lse1) / (torch.exp(lse1) + torch.exp(lse2))\n",
    "c2 = b2 * torch.exp(lse2) / (torch.exp(lse1) + torch.exp(lse2))\n",
    "\n",
    "print(a)\n",
    "print(torch.cat([c1, c2]), torch.allclose(a, torch.cat([c1, c2])))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But luckily we can rewrite it (`a/(a+b) = 1/(1 + b/a)`) and replace the exp-division by a sbtration of log-values (`exp(a)/exp(b) = exp(a-b)`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0856, 0.0951, 0.0390, 0.0288, 0.0738, 0.0110, 0.0129, 0.0293, 0.0190,\n",
      "        0.1449, 0.0112, 0.0428, 0.0573, 0.0605, 0.0056, 0.0085, 0.1530, 0.0500,\n",
      "        0.0044, 0.0677])\n",
      "tensor([0.0856, 0.0951, 0.0390, 0.0288, 0.0738, 0.0110, 0.0129, 0.0293, 0.0190,\n",
      "        0.1449, 0.0112, 0.0428, 0.0573, 0.0605, 0.0056, 0.0085, 0.1530, 0.0500,\n",
      "        0.0044, 0.0677])\n",
      "allclose:  True\n"
     ]
    }
   ],
   "source": [
    "d1 = b1 / (1 + torch.exp(lse2 - lse1))\n",
    "d2 = b2 / (1 + torch.exp(lse1 - lse2))\n",
    "print(a)\n",
    "print(torch.cat([d1, d2]))\n",
    "print(\"allclose: \", torch.allclose(a, torch.cat([d1, d2])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the fresh knowledge about softmax we can now take a look at the `update()` function that is used in the ring-flash-attention implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _update_out_and_lse(\n",
    "    out: torch.Tensor,\n",
    "    lse: torch.Tensor,\n",
    "    block_out: torch.Tensor,\n",
    "    block_lse: torch.Tensor,\n",
    ") -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    block_out = block_out.to(torch.float32)\n",
    "    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)\n",
    "\n",
    "    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))\n",
    "    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out\n",
    "\n",
    "    lse = new_lse\n",
    "    return out, lse"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
